import pandas as pd
import numpy as np
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
from transformers import BertTokenizer, BertTokenizerFast
from transformers import BertModel, AutoModel, AutoModelForMaskedLM
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, RobertaModel
import logging


# BertDNN
class BertDNN(nn.Module):
    """Pre-trained encoder followed by a two-hidden-layer feed-forward classifier.

    Only the vector at sequence position 0 of the encoder's last hidden state is
    classified (the [CLS] position for BERT-style tokenizers).

    Config keys read: 'pre_trained_model', 'embed_size', 'hidden_size_dnn_1',
    'dropout_dnn_1', 'hidden_size_dnn_2', 'dropout_dnn_2', 'class_num'.
    """

    def __init__(self, config):
        super().__init__()
        self.bert = AutoModel.from_pretrained(config['pre_trained_model'])
        # Layers are created in a fixed order so seeded weight initialisation is reproducible.
        first_width = config["hidden_size_dnn_1"]
        second_width = config["hidden_size_dnn_2"]
        self.fc1 = nn.Linear(config["embed_size"], first_width)
        self.drop1 = nn.Dropout(config["dropout_dnn_1"])
        self.fc2 = nn.Linear(first_width, second_width)
        self.drop2 = nn.Dropout(config["dropout_dnn_2"])
        self.fc = nn.Linear(second_width, config["class_num"])

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch_size, class_num)."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        features = encoded[0][:, 0]                                  # (batch_size, bert_emb_size)
        # Each hidden stage is Linear -> Dropout -> ReLU.
        for linear, dropout in ((self.fc1, self.drop1), (self.fc2, self.drop2)):
            features = F.relu(dropout(linear(features)))
        return self.fc(features)



# Bert_BiLSTM
class Bert_BiLSTM(nn.Module):
    """BERT encoder -> bidirectional LSTM -> feed-forward classifier.

    The classifier input concatenates the BiLSTM outputs at the first token and at
    the last *real* (non-padding) token of each sequence, i.e. 4 * hidden_size_lstm
    features. The LSTM is run on a packed sequence so padding never feeds into it.

    Config keys read: 'pre_trained_model', 'embed_size', 'hidden_size_lstm',
    'num_layers_lstm', 'fc_hidden_size_lstm', 'class_num', 'dropout'.
    """

    def __init__(self, config):
        super(Bert_BiLSTM, self).__init__()
        self.bert = BertModel.from_pretrained(config['pre_trained_model'])
        self.lstm = nn.LSTM(input_size=config["embed_size"],
                            hidden_size=config["hidden_size_lstm"],
                            num_layers=config["num_layers_lstm"],
                            bidirectional=True,
                            batch_first=True)

        # Instantiate a feed-forward classifier
        D_in, H, D_out = config["hidden_size_lstm"] * 4, config["fc_hidden_size_lstm"], config["class_num"]
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.Dropout(config['dropout']),
            nn.ReLU(),
            nn.Linear(H, D_out)
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch_size, class_num).

        ``attention_mask`` may be (batch_size, seq_len) or (batch_size, 1, seq_len);
        ``None`` treats every position as a real token. Padding is assumed to be on
        the right (the BertTokenizerFast default) -- TODO confirm for other tokenizers.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs[0]                                          # (batch_size, seq_len, emb_size)
        batch_size, seq_len = last_hidden_state.shape[0], last_hidden_state.shape[1]

        if attention_mask is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
        else:
            # Real-token count per row; reshape also flattens a (batch_size, 1, seq_len) mask.
            lengths = attention_mask.reshape(batch_size, -1).ne(0).sum(dim=1).clamp(min=1)

        # Packing keeps the LSTM from reading padding. Otherwise the forward state at the last
        # position, and the backward state at position 0, are both dominated by pad tokens.
        packed = nn.utils.rnn.pack_padded_sequence(last_hidden_state, lengths.cpu(),
                                                   batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True,
                                                  total_length=seq_len)        # (batch_size, seq_len, 2 * hidden_size)

        # Use the first time step and the last real time step as features.
        first_step = out[:, 0, :]
        last_index = lengths.to(out.device) - 1
        last_step = out[torch.arange(batch_size, device=out.device), last_index, :]
        out = torch.cat((first_step, last_step), -1)                            # (batch_size, 4 * hidden_size)
        logits = self.classifier(out)
        return logits


# Bert_LstmAtt
class Bert_LstmAtt(nn.Module):
    """BERT encoder -> bidirectional LSTM -> additive attention pooling -> linear classifier.

    The LSTM runs on a packed sequence, and padding positions are excluded from the
    attention softmax, so padding does not affect the prediction.

    Config keys read: 'pre_trained_model', 'embed_size', 'hidden_size_lstm',
    'num_layers_lstm', 'class_num', 'dropout'.
    """

    def __init__(self, config):
        super(Bert_LstmAtt, self).__init__()
        self.bert = BertModel.from_pretrained(config['pre_trained_model'])
        self.lstm = nn.LSTM(input_size=config['embed_size'],
                            hidden_size=config['hidden_size_lstm'],
                            num_layers=config['num_layers_lstm'],
                            bidirectional=True,
                            batch_first=True)

        hidden_dim = config['hidden_size_lstm'] * 2
        self.W = nn.Parameter(torch.randn(hidden_dim))  # learnable attention query vector

        self.fc = nn.Linear(hidden_dim, config['class_num'])
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch_size, class_num).

        ``attention_mask`` may be (batch_size, seq_len) or (batch_size, 1, seq_len);
        ``None`` treats every position as a real token. Padding is assumed to be on
        the right (the BertTokenizerFast default) -- TODO confirm for other tokenizers.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        x_embed = outputs[0]                                                  # (batch_size, seq_len, emb_size)
        batch_size, seq_len = x_embed.shape[0], x_embed.shape[1]

        if attention_mask is None:
            lengths = torch.full((batch_size,), seq_len, dtype=torch.long)
        else:
            # Real-token count per row; reshape also flattens a (batch_size, 1, seq_len) mask.
            lengths = attention_mask.reshape(batch_size, -1).ne(0).sum(dim=1).clamp(min=1)

        # Packing keeps the LSTM from reading padding positions.
        packed = nn.utils.rnn.pack_padded_sequence(x_embed, lengths.cpu(),
                                                   batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True,
                                                  total_length=seq_len)      # (batch_size, seq_len, 2 * hidden_size)
        out = torch.tanh(out)

        # attention: score each position against W and drop padding positions from the softmax.
        positions = torch.arange(seq_len, device=out.device)
        valid = positions.unsqueeze(0) < lengths.to(out.device).unsqueeze(1)  # (batch_size, seq_len)
        scores = torch.matmul(out, self.W)                                    # (batch_size, seq_len)
        scores = scores.masked_fill(~valid, torch.finfo(scores.dtype).min)
        alpha = F.softmax(scores, dim=1).unsqueeze(-1)                        # (batch_size, seq_len, 1)
        out = torch.sum(out * alpha, dim=1)                                   # (batch_size, 2 * hidden_size)
        logits = self.fc(self.dropout(out))                                   # (batch_size, class_num)
        return logits


# TextCNN
class GlobalMaxPool1d(nn.Module):
    """Max-pool over the entire last (temporal) axis.

    Input shape: (batch_size, channel, seq_len); output shape: (batch_size, channel, 1).
    The layer has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # One output bin that spans the whole sequence, i.e. a global max.
        return F.adaptive_max_pool1d(x, output_size=1)


# Bert_TextCNN
class Bert_TextCNN(nn.Module):
    """BERT encoder followed by a TextCNN head (parallel 1-D convolutions + global max-pooling).

    Config keys read: 'pre_trained_model', 'dropout', 'num_channels', 'class_num',
    'embed_size', 'kernel_sizes' ('num_channels' and 'kernel_sizes' are zipped pairwise).
    """

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained(config['pre_trained_model'])
        self.dropout = nn.Dropout(config['dropout'])
        channel_counts = config['num_channels']
        self.fc = nn.Linear(sum(channel_counts), config['class_num'])

        # The global max-pooling layer has no weights, so one instance is shared by all branches.
        self.pool = GlobalMaxPool1d()
        # One 1-D convolution per (out_channels, kernel_size) pair.
        self.convs = nn.ModuleList(
            nn.Conv1d(in_channels=config['embed_size'], out_channels=channels, kernel_size=width)
            for channels, width in zip(channel_counts, config['kernel_sizes'])
        )

    def forward(self, input_ids, attention_mask):
        """Return class logits of shape (batch_size, class_num)."""
        hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask)[0]   # (batch_size, seq_len, emb_size)
        channels_first = hidden.transpose(1, 2)                                     # (batch_size, embed_size, seq_len)

        # Each conv branch yields (batch_size, num_channels, 1) after global max-pooling.
        # The trailing dim is squeezed, and the branches are concatenated along the channel dim.
        branch_features = []
        for conv in self.convs:
            pooled = self.pool(F.relu(conv(channels_first)))
            branch_features.append(pooled.squeeze(-1))
        features = torch.cat(branch_features, dim=1)                                # (batch_size, sum(num_channels))
        return self.fc(self.dropout(features))


# Multi-layer perceptron (MLP) model
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer.
        output_size: number of output features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)


class Dataset(torch.utils.data.Dataset):
    """Map-style dataset of pre-tokenized texts and their labels.

    Every text is tokenized once, in the constructor, with BertTokenizerFast. Each
    item is ``(encoding, label)``. ``encoding`` is the tokenizer output
    (``return_tensors="pt"``, so each tensor has shape (1, max_length)). ``label`` is a
    numpy array built from the stored label.

    Args:
        texts: iterable of raw strings.
        labels: iterable of labels, one per text.
        pretrained_model: name or path passed to ``BertTokenizerFast.from_pretrained``.
        max_length: length every encoding is padded/truncated to (default 512).

    Raises:
        ValueError: if ``texts`` and ``labels`` contain a different number of items.
    """

    def __init__(self, texts, labels, pretrained_model, max_length=512):
        texts = list(texts)
        self.labels = list(labels)
        # Validate before the (possibly slow) tokenizer load. A mismatch would otherwise
        # drop texts silently or fail later with an IndexError.
        if len(texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, got {len(texts)} and {len(self.labels)}")
        tokenizer = BertTokenizerFast.from_pretrained(pretrained_model)
        # Fixed-length padding gives every item the same shape, so batches can be stacked.
        self.texts = [tokenizer(text,
                                padding='max_length', max_length=max_length, truncation=True,
                                return_tensors="pt") for text in texts]

    def __len__(self):
        return len(self.labels)

    def get_batch_labels(self, idx):
        # Fetch a batch of labels
        return np.array(self.labels[idx])

    def get_batch_texts(self, idx):
        # Fetch a batch of inputs
        return self.texts[idx]

    def __getitem__(self, idx):
        batch_texts = self.get_batch_texts(idx)
        batch_y = self.get_batch_labels(idx)
        return batch_texts, batch_y


def train_and_eval(model, train_dataloader, val_dataloaders, learning_rate, epochs, device, use_cuda=True, label_weight=None):
    """Train ``model`` with Adam and cross-entropy, evaluating on each validation loader every epoch.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns logits
            of shape (batch_size, num_classes).
        train_dataloader: yields ``(encoding, label)`` batches. ``encoding`` is a mapping
            holding ``'input_ids'`` and ``'attention_mask'`` tensors.
        val_dataloaders: sequence of validation dataloaders in the same format; may be
            empty or ``None``.
        learning_rate: Adam learning rate.
        epochs: number of passes over ``train_dataloader``.
        device: device that every batch is moved to.
        use_cuda: when true, the model and loss are also moved to ``device``.
        label_weight: optional per-class weights for ``nn.CrossEntropyLoss``.

    Returns:
        A list with one entry per validation dataloader. Each entry holds one list per
        epoch, and each of those holds per-batch numpy logit arrays. The list is empty
        when there are no validation dataloaders.
    """
    weight = None
    # `is not None` (not truthiness) so numpy arrays / tensors are accepted; float dtype
    # because CrossEntropyLoss rejects integer class weights.
    if label_weight is not None and len(label_weight) > 0:
        weight = torch.as_tensor(label_weight, dtype=torch.float)
    criterion = nn.CrossEntropyLoss(weight=weight)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    if use_cuda:
        model = model.to(device)
        criterion = criterion.to(device)

    val_dataloaders = list(val_dataloaders) if val_dataloaders else []
    # Always defined, so the final return works even without validation data.
    prediction_output = [[] for _ in val_dataloaders]

    for epoch_num in range(epochs):
        model.train()  # re-enable dropout after the previous epoch's validation
        total_acc_train = 0
        total_loss_train = 0.0
        for train_input, train_label in tqdm(train_dataloader):
            train_label = train_label.long().to(device)
            # Squeeze the tokenizer's extra (batch, 1, seq_len) dimension, as for input_ids.
            mask = train_input['attention_mask'].squeeze(1).to(device)
            input_id = train_input['input_ids'].squeeze(1).to(device)
            output = model(input_id, mask)
            batch_loss = criterion(output, train_label)
            total_loss_train += batch_loss.item()
            total_acc_train += (output.argmax(dim=1) == train_label).sum().item()
            model.zero_grad()
            batch_loss.backward()
            optimizer.step()

        # batch_loss is already a per-batch mean, so average over batches, not samples.
        train_loss = total_loss_train / max(len(train_dataloader), 1)
        train_acc = total_acc_train / max(len(train_dataloader.dataset), 1)

        if val_dataloaders:
            model.eval()  # disable dropout for deterministic validation
            for v_id, val_dataloader in enumerate(val_dataloaders):
                total_acc_val = 0
                total_loss_val = 0.0
                epoch_logits = []
                with torch.no_grad():
                    for val_input, val_label in val_dataloader:
                        val_label = val_label.long().to(device)
                        mask = val_input['attention_mask'].squeeze(1).to(device)
                        input_id = val_input['input_ids'].squeeze(1).to(device)
                        output = model(input_id, mask)
                        epoch_logits.append(output.detach().cpu().numpy())
                        total_loss_val += criterion(output, val_label).item()
                        total_acc_val += (output.argmax(dim=1) == val_label).sum().item()

                val_loss = total_loss_val / max(len(val_dataloader), 1)
                val_acc = total_acc_val / max(len(val_dataloader.dataset), 1)
                print(
                    f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} \
                        | Train Accuracy: {train_acc: .3f} \
                        | Val Loss: {val_loss: .3f} \
                        | Val Accuracy: {val_acc: .3f}')
                prediction_output[v_id].append(epoch_logits)
        else:
            print(f'Epochs: {epoch_num + 1} | Train Loss: {train_loss: .3f} | Train Accuracy: {train_acc: .3f}')

    return prediction_output


def evaluate(model, test_data, pretrained_model=None, batch_size=2):
    """Compute and print test accuracy, and return the per-batch logits.

    Args:
        model: module called as ``model(input_ids, attention_mask)`` that returns logits.
        test_data: either a ``torch.utils.data.Dataset`` yielding ``(encoding, label)``
            items (such as this module's ``Dataset``), or a ``(texts, labels)`` pair that
            is tokenized into a ``Dataset`` using ``pretrained_model``.
        pretrained_model: tokenizer name/path. Required only when ``test_data`` is a
            ``(texts, labels)`` pair.
        batch_size: evaluation batch size (default 2).

    Returns:
        List of per-batch numpy logit arrays, in dataset order.

    Raises:
        ValueError: if ``test_data`` is not a Dataset and ``pretrained_model`` is missing.
    """
    # Dataset now needs (texts, labels, pretrained_model), so the old Dataset(test_data)
    # call always raised TypeError.
    if isinstance(test_data, torch.utils.data.Dataset):
        test = test_data
    else:
        if pretrained_model is None:
            raise ValueError("pretrained_model is required when test_data is a (texts, labels) pair")
        texts, labels = test_data
        test = Dataset(texts, labels, pretrained_model)

    test_dataloader = torch.utils.data.DataLoader(test, batch_size=batch_size)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    model.eval()  # disable dropout during evaluation
    total_acc_test = 0
    prediction_output = []
    with torch.no_grad():
        for test_input, test_label in test_dataloader:
            test_label = test_label.long().to(device)
            # Squeeze the tokenizer's extra (batch, 1, seq_len) dimension, as for input_ids.
            mask = test_input['attention_mask'].squeeze(1).to(device)
            input_id = test_input['input_ids'].squeeze(1).to(device)
            output = model(input_id, mask)
            total_acc_test += (output.argmax(dim=1) == test_label).sum().item()
            prediction_output.append(output.detach().cpu().numpy())

    print(f'Test Accuracy: {total_acc_test / max(len(test), 1): .3f}')
    return prediction_output


if __name__ == "__main__":
    # No command-line entry point: this module only defines models, a dataset and training helpers.
    pass
